package logstorage

import (
	"fmt"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/VictoriaMetrics/VictoriaMetrics/lib/atomicutil"
	"github.com/VictoriaMetrics/VictoriaMetrics/lib/bytesutil"
	"github.com/VictoriaMetrics/VictoriaMetrics/lib/encoding"
	"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
	"github.com/VictoriaMetrics/VictoriaMetrics/lib/memory"
	"github.com/VictoriaMetrics/VictoriaMetrics/lib/slicesutil"

	"github.com/VictoriaMetrics/VictoriaLogs/lib/prefixfilter"
)

// pipeRunningStats processes '| running_stats ...' and '| total_stats ...' queries.
//
// See https://docs.victoriametrics.com/victorialogs/logsql/#running_stats-pipe
type pipeRunningStats struct {
	// isTotal is set for the `total_stats` pipe. Then every output row gets the stats calculated
	// over all the rows of its group instead of the running stats calculated over the rows
	// up to and including the current one.
	isTotal bool

	// byFields contains field names from the 'by(...)' clause; rows are grouped by the values of these fields.
	byFields []string

	// funcs contains stats functions to execute. It must contain at least one entry.
	funcs []pipeRunningStatsFunc
}

// pipeRunningStatsFunc is a single stats function of the pipe together with the name of its output field.
type pipeRunningStatsFunc struct {
	// f is stats function to execute
	f runningStatsFunc

	// resultName is the name of the output field generated by f.
	// It defaults to f.String() when the query doesn't set it explicitly.
	resultName string
}

// runningStatsFunc is a stats function, which can be used in `running_stats` and `total_stats` pipes.
type runningStatsFunc interface {
	// String returns string representation of runningStatsFunc.
	// It is also used as the default result field name.
	String() string

	// updateNeededFields must update pf with the fields needed for calculating the given running stats
	updateNeededFields(pf *prefixfilter.Filter)

	// newRunningStatsProcessor must create new runningStatsProcessor for calculating stats for the given runningStatsFunc.
	// A fresh processor is created for every group of rows.
	newRunningStatsProcessor() runningStatsProcessor
}

// runningStatsProcessor must process stats for some runningStatsFunc.
//
// All the runningStatsProcessor methods are called from a single goroutine at a time,
// so there is no need for internal synchronization.
type runningStatsProcessor interface {
	// updateRunningStats must update stats according to the given row.
	// sf is the runningStatsFunc, which created the processor.
	updateRunningStats(sf runningStatsFunc, row []Field)

	// getRunningStats must return the current value for the running stats.
	// It may be called multiple times without updateRunningStats calls in between.
	getRunningStats() string
}

// String returns the LogsQL representation of the pipe, e.g. `running_stats by (host) count() as rows`.
//
// It panics if the pipe has no stats functions, since such a pipe cannot be constructed by the parser.
func (ps *pipeRunningStats) String() string {
	if len(ps.funcs) == 0 {
		logger.Panicf("BUG: pipeRunningStats must contain at least a single runningStatsFunc")
	}

	var sb strings.Builder
	if ps.isTotal {
		sb.WriteString("total_stats")
	} else {
		sb.WriteString("running_stats")
	}

	if len(ps.byFields) > 0 {
		sb.WriteString(" by (")
		sb.WriteString(fieldNamesString(ps.byFields))
		sb.WriteString(")")
	}

	sb.WriteString(" ")
	for i, fn := range ps.funcs {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(fn.f.String())
		sb.WriteString(" as ")
		sb.WriteString(quoteTokenIfNeeded(fn.resultName))
	}
	return sb.String()
}

// splitToRemoteAndLocal returns the remote and the local parts of the pipe.
//
// The stats for a group can be calculated only after all the rows of the group are collected
// in a single place, so the pipe has no remote part and is executed locally as is.
func (ps *pipeRunningStats) splitToRemoteAndLocal(_ int64) (pipe, []pipe) {
	local := []pipe{ps}
	return nil, local
}

// canLiveTail reports whether the pipe can be used in live tailing mode.
//
// The pipe emits results only in flush(), after all the input rows are collected,
// so live tailing isn't supported.
func (ps *pipeRunningStats) canLiveTail() bool {
	const liveTailSupported = false
	return liveTailSupported
}

// canReturnLastNResults reports whether the pipe supports returning only the last N results.
//
// The stats attached to every row depend on the other rows of its group,
// so the pipe always reports that this isn't supported.
func (ps *pipeRunningStats) canReturnLastNResults() bool {
	const lastNSupported = false
	return lastNSupported
}

// updateNeededFields updates pf with the input fields required by the pipe.
//
// On entry, pf describes the fields requested from the pipe output (presumably by the pipes that follow it).
// Result fields are produced by the pipe itself, so they are denied in pf. The fields needed by a stats
// function are added only if its result field is requested.
func (ps *pipeRunningStats) updateNeededFields(pf *prefixfilter.Filter) {
	pfOrig := pf.Clone()

	for _, f := range ps.funcs {
		pf.AddDenyFilter(f.resultName)
		if pfOrig.MatchString(f.resultName) {
			f.f.updateNeededFields(pf)
		}
	}

	// byFields are needed unconditionally, since rows are grouped by them.
	for _, bf := range ps.byFields {
		pf.AddAllowFilter(bf)
	}

	// _time is needed unconditionally, since rows are ordered by it inside every group.
	// Without it, all the rows would get an empty timestamp and would be processed in arbitrary order.
	pf.AddAllowFilter("_time")
}

// hasFilterInWithQuery reports whether the pipe contains `in(<query>)` filters, which must be resolved before execution.
//
// The pipe always reports that there are no such filters.
func (ps *pipeRunningStats) hasFilterInWithQuery() bool {
	const hasFilterIn = false
	return hasFilterIn
}

// initFilterInValues returns the pipe with initialized `in(<query>)` filter values.
//
// The pipe reports no such filters (see hasFilterInWithQuery), so it is returned unchanged.
func (ps *pipeRunningStats) initFilterInValues(_ *inValuesCache, _ getFieldValuesFunc, _ bool) (pipe, error) {
	var p pipe = ps
	return p, nil
}

// visitSubqueries must call the given func for every subquery in the pipe.
//
// The pipe contains no subqueries, so the func is never called.
func (ps *pipeRunningStats) visitSubqueries(_ func(q *Query)) {
}

// newPipeProcessor returns a processor, which collects all the input rows and writes them
// with the calculated stats to ppNext on flush.
//
// The collected state is limited to 40% of memory.Allowed(); cancel is called when the limit is exceeded.
func (ps *pipeRunningStats) newPipeProcessor(_ int, stopCh <-chan struct{}, cancel func(), ppNext pipeProcessor) pipeProcessor {
	stateLimit := int64(float64(memory.Allowed()) * 0.4)

	psp := &pipeRunningStatsProcessor{
		ps:           ps,
		stopCh:       stopCh,
		cancel:       cancel,
		ppNext:       ppNext,
		maxStateSize: stateLimit,
	}
	psp.stateSizeBudget.Store(stateLimit)

	return psp
}

// pipeRunningStatsProcessor collects all the input rows in per-worker shards
// and calculates the stats over them in flush().
type pipeRunningStatsProcessor struct {
	// ps is the pipe being executed. stopCh is checked via needStop() in order to stop flush() early.
	// cancel is called once the memory budget is exhausted. ppNext receives the output rows.
	ps     *pipeRunningStats
	stopCh <-chan struct{}
	cancel func()
	ppNext pipeProcessor

	// shards holds per-worker state; it is indexed by the workerID passed to writeBlock().
	shards atomicutil.Slice[pipeRunningStatsProcessorShard]

	// maxStateSize is the limit in bytes for the collected state (40% of memory.Allowed()).
	// stateSizeBudget is the remaining shared budget; shards borrow stateSizeBudgetChunk bytes from it at a time.
	maxStateSize    int64
	stateSizeBudget atomic.Int64
}

// pipeRunningStatsProcessorShard holds the state collected by a single worker.
type pipeRunningStatsProcessorShard struct {
	// rows tracks all the rows collected by the shard.
	rows [][]Field

	// columnValues is a scratch buffer reused across writeBlock() calls for the per-column values of the current block.
	columnValues [][]string

	// stateSizeBudget is the remaining local memory budget of the shard in bytes. writeBlock() decreases it
	// by the approximate size of the stored rows, while pipeRunningStatsProcessor.writeBlock() refills it
	// from the shared budget when it becomes negative.
	stateSizeBudget int
}

// writeBlock appends a copy of every row from br to shard.rows. It decreases shard.stateSizeBudget
// by the approximate amount of memory occupied by the copies.
//
// Field names and values are cloned, so the stored rows do not reference memory owned by br.
func (shard *pipeRunningStatsProcessorShard) writeBlock(br *blockResult) {
	cs := br.getColumns()

	// Column names are cloned once per block and shared by all the rows of the block.
	// Strings are immutable, so sharing is safe. It saves len(cs) allocations and copies per row
	// compared to cloning the names for every row.
	columnNames := make([]string, len(cs))
	columnValues := slicesutil.SetLength(shard.columnValues, len(cs))
	for i, c := range cs {
		columnNames[i] = strings.Clone(c.name)
		columnValues[i] = c.getValues(br)
		shard.stateSizeBudget -= len(c.name)
	}
	shard.columnValues = columnValues

	for rowIdx := 0; rowIdx < br.rowsLen; rowIdx++ {
		fields := make([]Field, len(cs))
		shard.stateSizeBudget -= int(unsafe.Sizeof(Field{})) * len(fields)

		for j, name := range columnNames {
			v := columnValues[j][rowIdx]
			fields[j] = Field{
				Name:  name,
				Value: strings.Clone(v),
			}
			shard.stateSizeBudget -= len(v)
		}

		shard.rows = append(shard.rows, fields)
		shard.stateSizeBudget -= int(unsafe.Sizeof(fields))
	}
}

// writeBlock stores the rows from br in the shard owned by workerID.
//
// Before storing, the shard's local memory budget is refilled from the shared budget
// in stateSizeBudgetChunk portions until it becomes non-negative. If the shared budget runs out,
// the block is dropped and the processing is cancelled; flush() then reports the error.
func (psp *pipeRunningStatsProcessor) writeBlock(workerID uint, br *blockResult) {
	if br.rowsLen == 0 {
		return
	}

	shard := psp.shards.Get(workerID)
	for shard.stateSizeBudget < 0 {
		left := psp.stateSizeBudget.Add(-stateSizeBudgetChunk)
		if left >= 0 {
			shard.stateSizeBudget += stateSizeBudgetChunk
			continue
		}

		// The state size is too big, so the rows are dropped in order to avoid OOM crash.
		// Only the call that moved the shared budget from non-negative to negative cancels the processing,
		// so the remaining workers stop calling writeBlock() and save CPU time.
		if left+stateSizeBudgetChunk >= 0 {
			psp.cancel()
		}
		return
	}

	shard.writeBlock(br)
}

// flush writes all the collected rows to ppNext. Every row is extended with one field per stats function.
//
// Rows are grouped by the values of the by(...) fields. Groups are written in ascending order
// of their encoded by(...) values, and rows within a group are written in `_time` order.
// For running_stats, the stats attached to a row are calculated over the rows of its group
// up to and including that row. For total_stats, they are calculated over the whole group.
//
// An error is returned if the collected rows exceeded the memory limit. If stopCh is closed,
// flush stops early and returns nil.
func (psp *pipeRunningStatsProcessor) flush() error {
	if n := psp.stateSizeBudget.Load(); n <= 0 {
		return fmt.Errorf("cannot calculate [%s], since it requires more than %dMB of memory", psp.ps.String(), psp.maxStateSize/(1<<20))
	}

	type rowWithTimestamp struct {
		// nsecs is the `_time` value parsed as RFC3339 timestamp in nanoseconds.
		// It is valid only if hasNsecs is true.
		nsecs    int64
		hasNsecs bool

		// timestamp is the raw `_time` value.
		timestamp string

		fields []Field
	}

	// parseTimestamp parses the `_time` value, which is expected to be in RFC3339 format.
	// Timestamps are compared numerically instead of as strings. Go's time.RFC3339Nano formatting
	// trims trailing zeros from the fractional seconds, and such strings do not sort lexicographically
	// in time order: "...:05.1Z" sorts after "...:05.12Z", and "...:05Z" sorts after "...:05.5Z".
	parseTimestamp := func(s string) (int64, bool) {
		t, err := time.Parse(time.RFC3339Nano, s)
		if err != nil {
			return 0, false
		}
		return t.UnixNano(), true
	}

	// Group rows by the values of the by(...) fields.
	m := make(map[string][]rowWithTimestamp)
	var keyBuf []byte
	for _, shard := range psp.shards.All() {
		for _, row := range shard.rows {
			if needStop(psp.stopCh) {
				return nil
			}

			// Reuse keyBuf across rows; string(keyBuf) makes the copy stored in the map.
			keyBuf = keyBuf[:0]
			for _, bf := range psp.ps.byFields {
				v := getFieldValueByName(row, bf)
				keyBuf = encoding.MarshalBytes(keyBuf, bytesutil.ToUnsafeBytes(v))
			}
			key := string(keyBuf)

			timestamp := getFieldValueByName(row, "_time")
			nsecs, hasNsecs := parseTimestamp(timestamp)
			m[key] = append(m[key], rowWithTimestamp{
				nsecs:     nsecs,
				hasNsecs:  hasNsecs,
				timestamp: timestamp,
				fields:    row,
			})
		}
	}

	// Sort groups by their keys in order to get deterministic output.
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	wctx := &pipeRunningStatsWriter{
		ppNext: psp.ppNext,
	}

	funcs := psp.ps.funcs
	for _, key := range keys {
		if needStop(psp.stopCh) {
			return nil
		}

		rows := m[key]

		// Order rows by `_time`. Rows whose `_time` cannot be parsed (including rows without `_time`)
		// go first and are ordered by their raw values. The sort is stable, so rows with identical
		// timestamps keep their collection order.
		sort.SliceStable(rows, func(i, j int) bool {
			a, b := &rows[i], &rows[j]
			if a.hasNsecs != b.hasNsecs {
				return !a.hasNsecs
			}
			if a.nsecs != b.nsecs {
				return a.nsecs < b.nsecs
			}
			return a.timestamp < b.timestamp
		})

		// Every group gets fresh stats processors.
		sps := make([]runningStatsProcessor, len(funcs))
		for i, f := range funcs {
			sps[i] = f.f.newRunningStatsProcessor()
		}

		if psp.ps.isTotal {
			// total_stats: every row gets the stats over the whole group, so feed all the rows first.
			for _, row := range rows {
				for i, sp := range sps {
					sp.updateRunningStats(funcs[i].f, row.fields)
				}
			}
		}

		for _, row := range rows {
			fields := make([]Field, 0, len(row.fields)+len(sps))
			fields = append(fields, row.fields...)
			for i, sp := range sps {
				f := funcs[i]
				if !psp.ps.isTotal {
					// running_stats: account for the current row before reading the stats.
					sp.updateRunningStats(f.f, row.fields)
				}
				fields = append(fields, Field{
					Name:  f.resultName,
					Value: sp.getRunningStats(),
				})
			}
			wctx.writeRow(fields)
		}
	}

	wctx.flush()

	return nil
}

// pipeRunningStatsWriter accumulates output rows into blocks and sends them to ppNext.
// A new block is started whenever the ordered set of field names changes between rows,
// or when the buffered values reach 64_000 bytes.
type pipeRunningStatsWriter struct {
	ppNext pipeProcessor

	// rcs holds the buffered output columns, while br is the block used for sending them to ppNext.
	rcs []resultColumn
	br  blockResult

	// valuesLen is the total length of the buffered values, while rowsCount is the number of buffered rows.
	valuesLen int
	rowsCount int
}

// writeRow appends row to the block being built. The block is sent to ppNext
// once the buffered values reach 64_000 bytes.
//
// If the field names of row differ from the columns of the current block (by count, name or order),
// the current block is flushed first and a new block with the columns of row is started.
func (wctx *pipeRunningStatsWriter) writeRow(row []Field) {
	sameColumns := len(wctx.rcs) == len(row)
	for i := 0; sameColumns && i < len(row); i++ {
		sameColumns = wctx.rcs[i].name == row[i].Name
	}

	if !sameColumns {
		// Send the block with the previous set of columns to ppNext before switching to the new set.
		wctx.flush()

		newRcs := wctx.rcs[:0]
		for _, f := range row {
			newRcs = appendResultColumnWithName(newRcs, f.Name)
		}
		wctx.rcs = newRcs
	}

	rcs := wctx.rcs
	for i := range row {
		v := row[i].Value
		rcs[i].addValue(v)
		wctx.valuesLen += len(v)
	}
	wctx.rowsCount++

	// The 64_000 limit provides the best performance results.
	if wctx.valuesLen >= 64_000 {
		wctx.flush()
	}
}

// flush sends the buffered rows to ppNext as a single block and resets the buffered values.
// The current set of columns is kept for the subsequent rows.
//
// It does nothing when no rows are buffered, so ppNext never receives empty blocks.
// This happens, for example, when writeRow() switches to a new set of columns before the first row,
// when the final flush follows a flush triggered by the values size limit, or when there were no rows at all.
func (wctx *pipeRunningStatsWriter) flush() {
	if wctx.rowsCount == 0 {
		// Nothing is buffered: valuesLen is zero and all the columns are empty.
		return
	}

	rcs := wctx.rcs

	wctx.valuesLen = 0

	// Flush rcs to ppNext
	br := &wctx.br
	br.setResultColumns(rcs, wctx.rowsCount)
	wctx.rowsCount = 0
	wctx.ppNext.writeBlock(0, br)
	br.reset()
	for i := range rcs {
		rcs[i].resetValues()
	}
}

// parsePipeRunningStats parses the `running_stats ...` pipe starting at the `running_stats` keyword.
func parsePipeRunningStats(lex *lexer) (pipe, error) {
	const pipeName = "running_stats"

	if !lex.isKeyword(pipeName) {
		return nil, fmt.Errorf("expecting `running_stats`; got %q", lex.token)
	}
	lex.nextToken()

	return parsePipeRunningStatsExt(lex, pipeName)
}

// parsePipeRunningStatsExt parses the rest of the `running_stats` or `total_stats` pipe after the pipe name.
//
// pipeName "total_stats" enables total stats mode; any other value produces running stats.
// The expected syntax is `[[by] (f1, ..., fN)] func1 [[as] result1], ..., funcM [[as] resultM]`.
// The list of functions ends at `|`, `)` or the end of the query.
func parsePipeRunningStatsExt(lex *lexer, pipeName string) (pipe, error) {
	ps := &pipeRunningStats{
		isTotal: pipeName == "total_stats",
	}

	if lex.isKeyword("by", "(") {
		if lex.isKeyword("by") {
			lex.nextToken()
		}
		bfs, err := parseFieldNamesInParens(lex)
		if err != nil {
			return nil, fmt.Errorf("cannot parse 'by' clause: %w", err)
		}
		ps.byFields = bfs
	}

	// resultNameOwners maps every result name seen so far to the function producing it.
	resultNameOwners := make(map[string]runningStatsFunc)
	for {
		sf, err := parseRunningStatsFunc(lex)
		if err != nil {
			return nil, err
		}

		var resultName string
		switch {
		case lex.isKeyword(",", "|", ")", ""):
			// No explicit result name: use the function representation.
			resultName = sf.String()
		default:
			if lex.isKeyword("as") {
				lex.nextToken()
			}
			name, err := parseFieldName(lex)
			if err != nil {
				return nil, fmt.Errorf("cannot parse result name for [%s]: %w", sf, err)
			}
			resultName = name
		}

		if prev := resultNameOwners[resultName]; prev != nil {
			return nil, fmt.Errorf("cannot use identical result name %q for [%s] and [%s]", resultName, prev, sf)
		}
		resultNameOwners[resultName] = sf

		ps.funcs = append(ps.funcs, pipeRunningStatsFunc{
			f:          sf,
			resultName: resultName,
		})

		switch {
		case lex.isKeyword("|", ")", ""):
			return ps, nil
		case !lex.isKeyword(","):
			return nil, fmt.Errorf("unexpected token %q after [%s]; want ',', '|' or ')'", lex.token, sf)
		}
		lex.nextToken()
	}
}

// parseRunningStatsFunc parses a single running stats function starting at the current token.
//
// The function is selected by matching the current token against the names registered
// in getRunningStatsFuncParsers().
func parseRunningStatsFunc(lex *lexer) (runningStatsFunc, error) {
	for name, parse := range getRunningStatsFuncParsers() {
		if lex.isKeyword(name) {
			sf, err := parse(lex)
			if err != nil {
				return nil, fmt.Errorf("cannot parse %q func: %w", name, err)
			}
			return sf, nil
		}
	}
	return nil, fmt.Errorf("unknown stats func %q", lex.token)
}

// runningStatsFuncParser parses a single running stats function starting at the current token of lex.
type runningStatsFuncParser func(lex *lexer) (runningStatsFunc, error)

// runningStatsFuncParsers maps function names to their parsers.
// It is initialized lazily on the first getRunningStatsFuncParsers() call.
var (
	runningStatsFuncParsers     map[string]runningStatsFuncParser
	runningStatsFuncParsersOnce sync.Once
)

// getRunningStatsFuncParsers returns the parsers for all the supported running stats functions keyed by function name.
func getRunningStatsFuncParsers() map[string]runningStatsFuncParser {
	runningStatsFuncParsersOnce.Do(initRunningStatsFuncParsers)
	return runningStatsFuncParsers
}

func initRunningStatsFuncParsers() {
	runningStatsFuncParsers = map[string]runningStatsFuncParser{
		"count": parseRunningStatsCount,
		"max":   parseRunningStatsMax,
		"min":   parseRunningStatsMin,
		"sum":   parseRunningStatsSum,
	}
}
